{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-Head Attention\n",
    "\n",
    "Multi-Head Attention is an important part of all Transformer-based models.\n",
    "This tutorial shows how to write and optimize Multi-Head Attention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import ttnn\n",
    "from loguru import logger\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "device_id = 0\n",
    "device = ttnn.open_device(device_id=device_id)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write Multi-Head Attention with TT-NN\n",
    "\n",
    "Multi-head attention can be implemented in `torch` using the following six operations:\n",
    "\n",
    "1. `torch.matmul`\n",
    "2. `torch.add`\n",
    "3. `torch.reshape`\n",
    "4. `torch.permute`\n",
    "5. `torch.mul`\n",
    "6. `torch.softmax`\n",
    "\n",
    "`TT-NN` provides the same APIs for these operations, so multi-head attention can be implemented in much the same way.\n",
    "The one difference is that, when using `TT-NN`, you need to be mindful of the tensor layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    query_weight,\n",
    "    query_bias,\n",
    "    key_weight,\n",
    "    key_bias,\n",
    "    value_weight,\n",
    "    value_bias,\n",
    "    output_weight,\n",
    "    output_bias,\n",
    "    *,\n",
    "    num_heads,\n",
    "):\n",
    "    \"\"\"Multi-head attention written with basic TT-NN operations.\n",
    "\n",
    "    Projects `hidden_states` to query, key and value, splits each projection\n",
    "    into `num_heads` heads, scales and masks the attention scores, applies a\n",
    "    softmax over the last dimension, merges the heads again and applies the\n",
    "    output projection.\n",
    "\n",
    "    Returns the result of the output projection (`context @ output_weight + output_bias`).\n",
    "    \"\"\"\n",
    "    reshape_fallback = ttnn.get_fallback_function(ttnn.reshape)\n",
    "\n",
    "    batch_size, sequence_size, hidden_size = hidden_states.shape\n",
    "    head_size = hidden_size // num_heads\n",
    "    per_head_shape = (batch_size, sequence_size, num_heads, head_size)\n",
    "\n",
    "    def project_to_heads(weight, bias, axes):\n",
    "        \"\"\"Apply one linear projection, split it into heads and permute by `axes`.\"\"\"\n",
    "        projected = hidden_states @ weight\n",
    "        projected = projected + bias\n",
    "        # Switch to row-major layout for the fallback reshape, then back to tiles before permuting.\n",
    "        projected = ttnn.to_layout(projected, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
    "        projected = reshape_fallback(projected, per_head_shape)\n",
    "        projected = ttnn.to_layout(projected, layout=ttnn.TILE_LAYOUT)\n",
    "        return ttnn.permute(projected, axes)\n",
    "\n",
    "    # Query and value become (batch, heads, sequence, head_size);\n",
    "    # key becomes (batch, heads, head_size, sequence) so that `query @ key` gives the scores.\n",
    "    query = project_to_heads(query_weight, query_bias, (0, 2, 1, 3))\n",
    "    key = project_to_heads(key_weight, key_bias, (0, 2, 3, 1))\n",
    "    value = project_to_heads(value_weight, value_bias, (0, 2, 1, 3))\n",
    "\n",
    "    scale = 1 / (head_size**0.5)\n",
    "    attention_scores = query @ key\n",
    "    attention_scores = attention_scores * scale\n",
    "    attention_scores += attention_mask\n",
    "    attention_probs = ttnn.softmax(attention_scores, dim=-1, numeric_stable=False)\n",
    "\n",
    "    # Merge the heads back into a (batch, sequence, hidden) tensor.\n",
    "    merged_heads = attention_probs @ value\n",
    "    merged_heads = ttnn.permute(merged_heads, (0, 2, 1, 3))\n",
    "    merged_heads = ttnn.to_layout(merged_heads, layout=ttnn.ROW_MAJOR_LAYOUT)\n",
    "    merged_heads = reshape_fallback(merged_heads, (batch_size, sequence_size, hidden_size))\n",
    "    merged_heads = ttnn.to_layout(merged_heads, layout=ttnn.TILE_LAYOUT)\n",
    "\n",
    "    projected_output = merged_heads @ output_weight\n",
    "    return projected_output + output_bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the model is written, create input tensors to run and test it."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 6\n",
    "sequence_size = 384\n",
    "num_heads = 16\n",
    "head_size = 64\n",
    "hidden_size = num_heads * head_size"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize activations and weights with Torch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _randn_bf16(*shape):\n",
    "    \"\"\"Return a bfloat16 tensor of the given shape drawn with `torch.randn`.\"\"\"\n",
    "    return torch.randn(shape, dtype=torch.bfloat16)\n",
    "\n",
    "\n",
    "# The tensors are drawn one after another from the global torch RNG, in the order listed.\n",
    "torch_hidden_states = _randn_bf16(batch_size, sequence_size, hidden_size)\n",
    "torch_attention_mask = _randn_bf16(batch_size, 1, 1, sequence_size)\n",
    "torch_query_weight = _randn_bf16(hidden_size, hidden_size)\n",
    "torch_query_bias = _randn_bf16(hidden_size)\n",
    "torch_key_weight = _randn_bf16(hidden_size, hidden_size)\n",
    "torch_key_bias = _randn_bf16(hidden_size)\n",
    "torch_value_weight = _randn_bf16(hidden_size, hidden_size)\n",
    "torch_value_bias = _randn_bf16(hidden_size)\n",
    "torch_output_weight = _randn_bf16(hidden_size, hidden_size)\n",
    "torch_output_bias = _randn_bf16(hidden_size)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert activations and weights to TT-NN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_tile_on_device(torch_tensor, **kwargs):\n",
    "    \"\"\"Convert a torch tensor to a TT-NN tensor in tile layout on `device`.\"\"\"\n",
    "    return ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device, **kwargs)\n",
    "\n",
    "\n",
    "hidden_states = _to_tile_on_device(torch_hidden_states)\n",
    "attention_mask = _to_tile_on_device(torch_attention_mask)\n",
    "\n",
    "# Weights use the default memory config; biases are requested in L1.\n",
    "query_weight = _to_tile_on_device(torch_query_weight)\n",
    "query_bias = _to_tile_on_device(torch_query_bias, memory_config=ttnn.L1_MEMORY_CONFIG)\n",
    "key_weight = _to_tile_on_device(torch_key_weight)\n",
    "key_bias = _to_tile_on_device(torch_key_bias, memory_config=ttnn.L1_MEMORY_CONFIG)\n",
    "value_weight = _to_tile_on_device(torch_value_weight)\n",
    "value_bias = _to_tile_on_device(torch_value_bias, memory_config=ttnn.L1_MEMORY_CONFIG)\n",
    "output_weight = _to_tile_on_device(torch_output_weight)\n",
    "output_bias = _to_tile_on_device(torch_output_bias, memory_config=ttnn.L1_MEMORY_CONFIG)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the first iteration of Multi-Head Attention:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time.perf_counter() is a monotonic, high-resolution clock meant for measuring intervals;\n",
    "# time.time() is a wall clock and can jump if the system time is adjusted.\n",
    "start = time.perf_counter()\n",
    "multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    query_weight,\n",
    "    query_bias,\n",
    "    key_weight,\n",
    "    key_bias,\n",
    "    value_weight,\n",
    "    value_bias,\n",
    "    output_weight,\n",
    "    output_bias,\n",
    "    num_heads=num_heads,\n",
    ")\n",
    "end = time.perf_counter()\n",
    "duration = end - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(f\"Multi-head attention ran in {duration} seconds for the first iteration\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run a subsequent iteration of Multi-Head Attention:\n",
    "\n",
    "The first iteration of Multi-Head Attention will take several seconds to process. \n",
    "As `TT-NN` configures and compiles device code for operations on-the-fly, most execution time is spent compiling. \n",
    "\n",
    "Fortunately, configuration and compiled device code are stored in a program cache. \n",
    "This means that subsequent iterations will be significantly faster, about two orders of magnitude faster than the first iteration. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time.perf_counter() is a monotonic, high-resolution clock meant for measuring intervals;\n",
    "# time.time() is a wall clock and can jump if the system time is adjusted.\n",
    "start = time.perf_counter()\n",
    "output = multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    query_weight,\n",
    "    query_bias,\n",
    "    key_weight,\n",
    "    key_bias,\n",
    "    value_weight,\n",
    "    value_bias,\n",
    "    output_weight,\n",
    "    output_bias,\n",
    "    num_heads=num_heads,\n",
    ")\n",
    "end = time.perf_counter()\n",
    "duration = end - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(f\"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Write an optimized version of Multi-Head Attention:\n",
    "\n",
    "The optimized version of Multi-Head Attention can be written by:\n",
    "\n",
    "- Tilizing all of the tensors ahead of time.\n",
    "- Using more performant matmuls that fuse bias and specify the number of cores they execute on.\n",
    "- Putting every tensor into L1.\n",
    "- Using the `bfloat8_b` data type.\n",
    "- Using custom `ttnn.transformer` operations instead of `ttnn.permute` and `ttnn.reshape` operations.\n",
    "\n",
    "`ttnn.deallocate` calls are required; otherwise, cores on the device will run out of L1 memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimized_multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    fused_qkv_weight,\n",
    "    fused_qkv_bias,\n",
    "    self_output_weight,\n",
    "    self_output_bias,\n",
    "    *,\n",
    "    num_heads,\n",
    "    num_cores_x=12,\n",
    "):\n",
    "    \"\"\"Multi-head attention built from fused TT-NN operations.\n",
    "\n",
    "    Runs a single fused QKV `ttnn.linear`, splits and later concatenates the\n",
    "    heads with `ttnn.transformer` operations, and applies the mask and softmax\n",
    "    with `ttnn.transformer.attention_softmax_`. Every operation output is\n",
    "    requested in L1, and every matmul/linear runs on a core grid of\n",
    "    `batch_size` rows by `num_cores_x` columns. The fused projection, query,\n",
    "    key, attention probabilities and context tensors are explicitly\n",
    "    deallocated after their last use.\n",
    "\n",
    "    Returns the output projection as a bfloat16 tensor.\n",
    "    \"\"\"\n",
    "    batch_size, _, hidden_size = hidden_states.shape\n",
    "    head_size = hidden_size // num_heads\n",
    "\n",
    "    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)\n",
    "\n",
    "    # Settings shared by every operation below.\n",
    "    l1_config = ttnn.L1_MEMORY_CONFIG\n",
    "    core_grid = ttnn.CoreGrid(y=batch_size, x=num_cores_x)\n",
    "\n",
    "    qkv_projection = ttnn.linear(\n",
    "        hidden_states,\n",
    "        fused_qkv_weight,\n",
    "        bias=fused_qkv_bias,\n",
    "        memory_config=l1_config,\n",
    "        dtype=ttnn.bfloat8_b,\n",
    "        core_grid=core_grid,\n",
    "    )\n",
    "    query, key, value = ttnn.transformer.split_query_key_value_and_split_heads(\n",
    "        qkv_projection,\n",
    "        memory_config=l1_config,\n",
    "        num_heads=num_heads,\n",
    "    )\n",
    "    ttnn.deallocate(qkv_projection)\n",
    "\n",
    "    scores = ttnn.matmul(query, key, memory_config=l1_config, dtype=ttnn.bfloat16, core_grid=core_grid)\n",
    "    ttnn.deallocate(query)\n",
    "    ttnn.deallocate(key)\n",
    "\n",
    "    probs = ttnn.transformer.attention_softmax_(scores, attention_mask=attention_mask, head_size=head_size)\n",
    "\n",
    "    context = ttnn.matmul(probs, value, memory_config=l1_config, dtype=ttnn.bfloat8_b, core_grid=core_grid)\n",
    "    ttnn.deallocate(probs)\n",
    "\n",
    "    merged_context = ttnn.transformer.concatenate_heads(context, memory_config=l1_config)\n",
    "    ttnn.deallocate(context)\n",
    "\n",
    "    self_output = ttnn.linear(\n",
    "        merged_context,\n",
    "        self_output_weight,\n",
    "        bias=self_output_bias,\n",
    "        memory_config=l1_config,\n",
    "        dtype=ttnn.bfloat16,\n",
    "        core_grid=core_grid,\n",
    "    )\n",
    "    ttnn.deallocate(merged_context)\n",
    "\n",
    "    return self_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pre-process the parameters of the optimized model:\n",
    "\n",
    "1. Fuse QKV weights and biases.\n",
    "2. Reshape and tilize for optimized operations using `preprocess_linear_weight` and `preprocess_linear_bias`.\n",
    "3. Move to device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ttnn.model_preprocessing import (\n",
    "    preprocess_linear_bias,\n",
    "    preprocess_linear_weight,\n",
    ")\n",
    "\n",
    "# Fuse the query, key and value parameters along the last dimension.\n",
    "torch_qkv_weight = torch.cat((torch_query_weight, torch_key_weight, torch_value_weight), dim=-1)\n",
    "torch_qkv_bias = torch.cat((torch_query_bias, torch_key_bias, torch_value_bias), dim=-1)\n",
    "\n",
    "# Preprocess all four parameters first, then move them to the device in the same order.\n",
    "qkv_weight, output_weight = (\n",
    "    preprocess_linear_weight(weight.T, dtype=ttnn.bfloat16)\n",
    "    for weight in (torch_qkv_weight, torch_output_weight)\n",
    ")\n",
    "qkv_bias, output_bias = (\n",
    "    preprocess_linear_bias(bias, dtype=ttnn.bfloat16)\n",
    "    for bias in (torch_qkv_bias, torch_output_bias)\n",
    ")\n",
    "\n",
    "# Weights use the default memory config; biases are requested in L1.\n",
    "qkv_weight = ttnn.to_device(qkv_weight, device)\n",
    "qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)\n",
    "output_weight = ttnn.to_device(output_weight, device)\n",
    "output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the first iteration of optimized Multi-Head Attention:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `hidden_states` was created in tile layout and `optimized_multi_head_attention` converts its\n",
    "# input to tile layout itself, so no layout conversion is done (or timed) in this cell.\n",
    "# time.perf_counter() is a monotonic, high-resolution clock meant for measuring intervals.\n",
    "start = time.perf_counter()\n",
    "optimized_output = optimized_multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    qkv_weight,\n",
    "    qkv_bias,\n",
    "    output_weight,\n",
    "    output_bias,\n",
    "    num_heads=num_heads,\n",
    ")\n",
    "end = time.perf_counter()\n",
    "duration = end - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(f\"Optimized multi-head attention ran in {duration} seconds for the first iteration\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run a subsequent iteration of optimized Multi-Head Attention:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time.perf_counter() is a monotonic, high-resolution clock meant for measuring intervals;\n",
    "# time.time() is a wall clock and can jump if the system time is adjusted.\n",
    "start = time.perf_counter()\n",
    "optimized_output = optimized_multi_head_attention(\n",
    "    hidden_states,\n",
    "    attention_mask,\n",
    "    qkv_weight,\n",
    "    qkv_bias,\n",
    "    output_weight,\n",
    "    output_bias,\n",
    "    num_heads=num_heads,\n",
    ")\n",
    "end = time.perf_counter()\n",
    "duration = end - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(f\"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache\")"
   ]
  },
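  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the optimized multi-head attention is substantially faster than the initial version: roughly 20x for a subsequent iteration in the sample output at the end of this tutorial (about 0.002 seconds versus 0.055 seconds). The exact speedup depends on the device."
   ]
  },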
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the output of the optimized version matches the output of the original implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_output = ttnn.to_torch(output)\n",
    "torch_optimized_output = ttnn.to_torch(optimized_output)\n",
    "\n",
    "assert torch.allclose(torch_output, torch_optimized_output)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Close the device:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ttnn.close_device(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full Example and Output\n",
    "\n",
    "Let's put everything together in a complete example that can be run\n",
    "directly:\n",
    "\n",
    "[ttnn_multihead_attention.py](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/tutorials/basic_python/ttnn_multihead_attention.py)\n",
    "\n",
    "Run the following script to generate the output:\n",
    "\n",
    "``` console\n",
    "$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_multihead_attention.py\n",
    "2025-07-07 13:06:38.768 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.769 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.776 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:190)\n",
    "2025-07-07 13:06:38.776 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.777 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.783 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.784 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.790 | info     |   SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)\n",
    "2025-07-07 13:06:38.887 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:06:38.931 | info     |   SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)\n",
    "2025-07-07 13:06:38.942 | info     |   SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)\n",
    "2025-07-07 13:06:39.027 | info     |           Metal | AI CLK for device 0 is:   1000 MHz (metal_context.cpp:128)\n",
    "2025-07-07 13:06:39.603 | info     |           Metal | Initializing device 0. Program cache is enabled (device.cpp:428)\n",
    "2025-07-07 13:06:39.605 | warning  |           Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)\n",
    "2025-07-07 13:06:51.001 | INFO     | __main__:main:132 - Multi-head attention ran in 9.265338897705078 seconds for the first iteration\n",
    "2025-07-07 13:06:51.056 | INFO     | __main__:main:151 - Multi-head attention ran in 0.05480194091796875 seconds for the subsequent iteration because of the program cache\n",
    "2025-07-07 13:06:55.363 | INFO     | __main__:main:259 - Optimized multi-head attention ran in 4.2866740226745605 seconds for the first iteration\n",
    "2025-07-07 13:06:55.366 | INFO     | __main__:main:274 - Optimized multi-head attention ran in 0.002416849136352539 seconds for the subsequent iteration because of the program cache\n",
    "2025-07-07 13:06:55.417 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "2025-07-07 13:06:55.418 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)\n",
    "2025-07-07 13:06:55.418 | info     |           Metal | Closing device 0 (device.cpp:468)\n",
    "2025-07-07 13:06:55.418 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:783)\n",
    "2025-07-07 13:06:55.460 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
